"""
    Calculates the clustering based on
"""
import os
import random

import numpy as np
from scipy.signal import normalize
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd

from src.embedding_generators.bert_embeddings import BertEmbedding
from src.knowledge_graphs.wordnet import WordNetDataset
from src.models.cluster.chinesewhispers import MTChineseWhispers
from src.resources.corpus import Corpus
from src.resources.corpus_semcor import CorpusSemCor
from src.sampler.sample_embedding_and_sentences import get_bert_embeddings_and_sentences
from src.utils.create_experiments_folder import randomString

from src.utils.thesaurus_io import print_thesaurus


def predict_clustering(embedding_matrix, arguments=None, max_components=20):
    """Cluster the rows of ``embedding_matrix`` with Chinese Whispers.

    The matrix is standardised column-wise with ``StandardScaler``, projected
    onto at most ``max_components`` principal components, and the projected
    rows are clustered with ``MTChineseWhispers``.

    :param embedding_matrix: 2-D array with one embedding per row
        (assumed shape ``(n_samples, n_features)``).
    :param arguments: hyper-parameter dict passed to ``MTChineseWhispers``.
        When ``None``, the best configuration found during tuning is used.
    :param max_components: upper bound on the number of PCA components. The
        value actually used is also capped by the number of samples and
        features, since PCA cannot extract more components than either.
    :return: the result of ``MTChineseWhispers.fit_predict``; the caller
        expects one label per input row (it checks ``len(X) == len(labels)``).
    """
    print("X shape")
    print(embedding_matrix.shape)

    # Normalize matrix
    X = StandardScaler().fit_transform(embedding_matrix)

    # PCA requires n_components <= min(n_samples, n_features); bounding only by
    # the sample count breaks for low-dimensional inputs.
    n_components = min(max_components, *X.shape)
    pca_model = PCA(n_components=n_components, whiten=False)
    X = pca_model.fit_transform(X)

    if arguments is None:
        # Best parameters from the hyper-parameter search (objective 0.4569).
        # Lower-scoring candidates that were also tried:
        #   std_multiplier=-3.0,   remove_hub_number=55, min_cluster_size=1  -> 0.4007
        #   std_multiplier=2.0614, remove_hub_number=0,  min_cluster_size=42 -> 0.4337
        # A fresh dict is built per call so the clusterer never shares state.
        arguments = {
            'std_multiplier': 1.3971661365029329,
            'remove_hub_number': 0,
            'min_cluster_size': 31
        }

    cluster_model = MTChineseWhispers(arguments)

    predicted_labels = cluster_model.fit_predict(X)
    return predicted_labels


if __name__ == "__main__":
    print("Comparing our clusters with other clusters ...")
    print("This time we take into account the true clusters (take it from the other files ..")

    polysemous_words = [
        # ' thought ', ' made ',  # ' was ',
        # ' only ', ' central ', ' pizza '
        ' table ',
        ' bank ',
        ' cold ',
        ' table ',
        ' good ',
        ' mouse ',
        ' was ',
        ' key ',
        ' arms ',
        ' was ',
        ' thought ',
        ' pizza ',
        ' made ',
        ' book '
    ]

    corpus = Corpus()
    corpus_semcor = CorpusSemCor()
    # ALso take the second corpus to check if th
    lang_model = BertEmbedding(corpus=corpus_semcor)
    wordnet_model = WordNetDataset()

    savepath = randomString()

    for tgt_word in polysemous_words:
        print("Looking at word", tgt_word)

        number_of_senses = wordnet_model.get_number_of_senses("".join(tgt_word.split()))

        print("Getting embeddings from BERT")
        tuples_semcor, true_cluster_labels_semcor, _  = get_bert_embeddings_and_sentences(model=lang_model, corpus=corpus_semcor, tgt_word=tgt_word)
        tuples, _, _  = get_bert_embeddings_and_sentences(model=lang_model, corpus=corpus, tgt_word=tgt_word)

        print("semcor tuples and normal tuples are")

        print(tuples_semcor)
        print(len(tuples_semcor))

        print(tuples)
        print(len(tuples))

        # Predict the clustering for the combined corpus ...
        X = np.concatenate(
            [x[1].reshape(1, -1) for x in (tuples_semcor + tuples)], axis=0
        )
        sentences = [
            x[0] for x in (tuples_semcor + tuples)
        ]

        # Labels also should be a python list
        labels = predict_clustering(
            X
        ).tolist()

        print("Printing items ...")

        print(len(X), len(labels), len(sentences))

        assert len(X) == len(labels), (
            len(X), len(labels)
        )
        assert len(sentences) == len(labels), (
            len(sentences), len(labels)
        )

        print_thesaurus(
            sentences=sentences,
            clusters=labels,
            true_clusters=None,
            word=tgt_word,
            savepath=savepath
        )


